{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OpenAI使用`tiktoken`将文本拆分为token。本notebook介绍OpenAI如何对token进行计数。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "编码方法决定了文本被拆分为token的方式。OpenAI在不同的模型中使用以下3种`tiktoken`支持的编码方法：\n",
    "\n",
    "1. cl100k_base: gpt-4, gpt-3.5-turbo, text-embedding-ada-002\n",
    "2. p50k_base: text-davinci-002, text-davinci-003\n",
    "3. r50k_base 或 gpt2: GPT-3模型，如davinci"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 安装`tiktoken`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Silence pip with its own --quiet flag rather than redirecting to /dev/null,\n",
    "# which only exists on POSIX shells.\n",
    "%pip install --quiet --upgrade tiktoken"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "\n",
    "# Load an encoding directly by its name.\n",
    "encoding = tiktoken.get_encoding(\"p50k_base\")\n",
    "# Load the encoding that tiktoken associates with a given model name.\n",
    "encoding_for_model = tiktoken.encoding_for_model(\"gpt-4\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[19526, 254, 25001, 121, 171, 120, 234, 17312, 233, 20998, 233]\n",
      "[57668, 53901, 3922, 4916, 233, 98915]\n"
     ]
    }
   ],
   "source": [
    "# Sample text that is tokenized with both encodings.\n",
    "text_chinese = \"你好，朋友\"\n",
    "\n",
    "# Print the token ids from the named encoding first, then from the model's encoding.\n",
    "for tokenizer in (encoding, encoding_for_model):\n",
    "    print(tokenizer.encode(text_chinese))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. 解码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "你好，朋友\n",
      "你好，朋友\n"
     ]
    }
   ],
   "source": [
    "# Token ids to turn back into text, one list per encoding.\n",
    "named_encoding_token_ids = [19526, 254, 25001, 121, 171, 120, 234, 17312, 233, 20998, 233]\n",
    "model_encoding_token_ids = [57668, 53901, 3922, 4916, 233, 98915]\n",
    "\n",
    "# Decode each list with its corresponding encoding.\n",
    "print(encoding.decode(named_encoding_token_ids))\n",
    "print(encoding_for_model.decode(model_encoding_token_ids))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. OpenAI的Chat API的Token计数方式，参考官方文档[链接](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def num_tokens_from_messages(messages, model=\"gpt-3.5-turbo-0301\"):\n",
    "    \"\"\"Return the number of prompt tokens used by a list of chat messages.\n",
    "\n",
    "    Args:\n",
    "        messages: List of dicts; every value of every dict is tokenized, and a\n",
    "            `name` key additionally applies the per-name adjustment.\n",
    "        model: Model name. `gpt-3.5-turbo` and `gpt-4` are counted as\n",
    "            `gpt-3.5-turbo-0301` and `gpt-4-0314` respectively.\n",
    "\n",
    "    Returns:\n",
    "        Per-message overhead plus the encoded length of every value, the\n",
    "        per-name adjustment, and 3 tokens that prime the reply.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: If `model` is not one of the handled names.\n",
    "    \"\"\"\n",
    "    # Redirect floating aliases before doing any work, so the encoding is\n",
    "    # loaded once (by the recursive call) instead of loaded here and discarded.\n",
    "    if model == \"gpt-3.5-turbo\":\n",
    "        print(\"Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.\")\n",
    "        return num_tokens_from_messages(messages, model=\"gpt-3.5-turbo-0301\")\n",
    "    if model == \"gpt-4\":\n",
    "        print(\"Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.\")\n",
    "        return num_tokens_from_messages(messages, model=\"gpt-4-0314\")\n",
    "\n",
    "    if model == \"gpt-3.5-turbo-0301\":\n",
    "        tokens_per_message = 4  # every message follows <|start|>{role/name}\\n{content}<|end|>\\n\n",
    "        tokens_per_name = -1  # if there's a name, the role is omitted\n",
    "    elif model == \"gpt-4-0314\":\n",
    "        tokens_per_message = 3\n",
    "        tokens_per_name = 1\n",
    "    else:\n",
    "        raise NotImplementedError(f\"\"\"num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.\"\"\")\n",
    "\n",
    "    # Load the encoding only after the model has been validated, so an\n",
    "    # unsupported model no longer prints a fallback warning right before raising.\n",
    "    try:\n",
    "        encoding = tiktoken.encoding_for_model(model)\n",
    "    except KeyError:\n",
    "        print(\"Warning: model not found. Using cl100k_base encoding.\")\n",
    "        encoding = tiktoken.get_encoding(\"cl100k_base\")\n",
    "\n",
    "    num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>\n",
    "    for message in messages:\n",
    "        num_tokens += tokens_per_message\n",
    "        for key, value in message.items():\n",
    "            num_tokens += len(encoding.encode(value))\n",
    "            if key == \"name\":\n",
    "                num_tokens += tokens_per_name\n",
    "    return num_tokens"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "5. 示例代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-3.5-turbo-0301\n",
      "67 prompt tokens counted by num_tokens_from_messages().\n",
      "67 prompt tokens counted by the OpenAI API.\n",
      "\n",
      "gpt-4-0314\n",
      "67 prompt tokens counted by num_tokens_from_messages().\n",
      "67 prompt tokens counted by the OpenAI API.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import openai\n",
    "\n",
    "# Two-message conversation used to compare both token counts.\n",
    "system_message = {\n",
    "    \"role\": \"system\",\n",
    "    \"content\": \"你是翻译助理，请帮我将英文翻译成中文，谢谢。请只回复翻译文字，不要回复其他内容。\",\n",
    "}\n",
    "user_message = {\n",
    "    \"role\": \"user\",\n",
    "    \"name\": \"Alice\",\n",
    "    \"content\": \"The sky is blue.\",\n",
    "}\n",
    "example_messages = [system_message, user_message]\n",
    "\n",
    "for model in (\"gpt-3.5-turbo-0301\", \"gpt-4-0314\"):\n",
    "    print(model)\n",
    "\n",
    "    # Token count computed locally by num_tokens_from_messages().\n",
    "    local_count = num_tokens_from_messages(example_messages, model)\n",
    "    print(f\"{local_count} prompt tokens counted by num_tokens_from_messages().\")\n",
    "\n",
    "    # Token count reported by the OpenAI API. Only the usage figure is read,\n",
    "    # so the completion is capped at a single token.\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=model,\n",
    "        messages=example_messages,\n",
    "        temperature=0,\n",
    "        max_tokens=1,\n",
    "    )\n",
    "    api_count = response[\"usage\"][\"prompt_tokens\"]\n",
    "    print(f\"{api_count} prompt tokens counted by the OpenAI API.\")\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
